Skip to content

Dump activation shardings#3080

Merged
copybara-service[bot] merged 1 commit into AI-Hypercomputer:main from
CIeNET-International:charlesli/input_sharding
Feb 26, 2026
Merged

Dump activation shardings#3080
copybara-service[bot] merged 1 commit into AI-Hypercomputer:main from
CIeNET-International:charlesli/input_sharding

Conversation

@charlesli640
Copy link
Collaborator

@charlesli640 charlesli640 commented Feb 4, 2026

Description

To dump activation shardings to a golden file for further comparison. This can be included in a unit test in case a future code change touches activation shardings. This PR is the initial submission for the sharding dump JSON files.

Output

The output format is readable and comparable by both humans and machines. For example, the tests deepseek2-16b/v5p-16/slice_1 activation dump is shown below:

{
  "Activation Sharding Dump": [
    {
      "deepseek/inputs: bfloat16[96,2048,2048]": {
        "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
        "PartitionSpec": "P('fsdp', None, None)"
      }
    },
    {
      "deepseek/pre_attention_norm: bfloat16[96,2048,2048]": {
        "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
        "PartitionSpec": "P('fsdp', None, None)"
      }
    },
    {
      "attention_mla/inputs_q: bfloat16[96,2048,2048]": {
        "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')",
        "PartitionSpec": "P('fsdp', None, None)"
      }
    },
    {
      "attention_mla/inputs_kv: bfloat16[96,2048,2048]": {
        "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')",
        "PartitionSpec": "P('fsdp', None, None)"
      }
    },
    {
      "attention_mla/q_nope: bfloat16[96,2048,16,128]": {
        "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
        "PartitionSpec": "P('fsdp', None, None, None)"
      }
    },
    {
      "attention_mla/q_pe: bfloat16[96,2048,16,64]": {
        "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
        "PartitionSpec": "P('fsdp', None, None, None)"
      }
    },
    {
      "attention_mla/query: bfloat16[96,2048,16,192]": {
        "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
        "PartitionSpec": "P('fsdp', None, None, None)"
      }
    },
    {
      "attention_mla/key_nope: bfloat16[96,2048,16,128]": {
        "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
        "PartitionSpec": "P('fsdp', None, None, None)"
      }
    },
    {
      "attention_mla/key_rope: bfloat16[96,2048,16,64]": {
        "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
        "PartitionSpec": "P('fsdp', None, None, None)"
      }
    },
    {
      "attention_mla/key: bfloat16[96,2048,16,192]": {
        "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
        "PartitionSpec": "P('fsdp', None, None, None)"
      }
    },
    {
      "attention_mla/value: bfloat16[96,2048,16,128]": {
        "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
        "PartitionSpec": "P('fsdp', None, None, None)"
      }
    },
    {
      "attention_op/query: bfloat16[96,16,2048,192]": {
        "logic_axes": "Unknown",
        "PartitionSpec": "P('fsdp', None, None, None)"
      }
    },
    {
      "attention_op/key: bfloat16[96,16,2048,192]": {
        "logic_axes": "Unknown",
        "PartitionSpec": "P('fsdp', None, None, None)"
      }
    },
    {
      "attention_op/value: bfloat16[96,16,2048,128]": {
        "logic_axes": "Unknown",
        "PartitionSpec": "P('fsdp', None, None, None)"
      }
    },
    {
      "attention_mla/out: bfloat16[96,2048,16,128]": {
        "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')",
        "PartitionSpec": "P('fsdp', None, None, None)"
      }
    },
    {
      "deepseek/attention_result: bfloat16[96,2048,2048]": {
        "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
        "PartitionSpec": "P('fsdp', None, None)"
      }
    },
    {
      "deepseek/post_attention_norm: bfloat16[96,2048,2048]": {
        "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
        "PartitionSpec": "P('fsdp', None, None)"
      }
    },
    {
      "linears/x: bfloat16[96,2048,10944]": {
        "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')",
        "PartitionSpec": "P('fsdp', None, None)"
      }
    },
    {
      "deepseek/mlp: bfloat16[96,2048,2048]": {
        "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
        "PartitionSpec": "P('fsdp', None, None)"
      }
    },
    {
      "deepseek/x: bfloat16[96,2048,2048]": {
        "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
        "PartitionSpec": "P('fsdp', None, None)"
      }
    },
    {
      "moe/inputs: bfloat16[96,2048,2048]": {
        "logic_axes": "('activation_batch', 'activation_norm_length', None)",
        "PartitionSpec": "P('fsdp', None, None)"
      }
    },
    {
      "moe/gate_logits: bfloat16[96,2048,64]": {
        "logic_axes": "('activation_batch', 'activation_norm_length', None)",
        "PartitionSpec": "P('fsdp', None, None)"
      }
    },
    {
      "linears/x: bfloat16[96,2048,2816]": {
        "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')",
        "PartitionSpec": "P('fsdp', None, None)"
      }
    },
    {
      "deepseek/mlp_lnx: bfloat16[96,2048,2048]": {
        "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')",
        "PartitionSpec": "P('fsdp', None, None)"
      }
    }
  ]
}

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov
Copy link

codecov bot commented Feb 4, 2026

Codecov Report

❌ Patch coverage is 84.21053% with 6 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/utils/sharding.py 80.00% 3 Missing and 3 partials ⚠️

📢 Thoughts on this report? Let us know!

@gobbleturk
Copy link
Collaborator

I think this LGTM although there are a lot of names to review! How did you generate these names?

@charlesli640 charlesli640 marked this pull request as draft February 5, 2026 01:01
@charlesli640 charlesli640 force-pushed the charlesli/input_sharding branch from 4b17fdb to 511be4b Compare February 5, 2026 17:59
@charlesli640
Copy link
Collaborator Author

I think this LGTM although there are a lot of names to review! How did you generate these names?

These names are generated locally from <file_name>/<variable_name>. Sometimes a name may not correctly reflect the actual model/layer, but it basically serves as an identifier/key for logging, dumping, and comparison purposes.

@charlesli640 charlesli640 force-pushed the charlesli/input_sharding branch 5 times, most recently from 8071699 to 486ebfb Compare February 10, 2026 18:49
@charlesli640 charlesli640 marked this pull request as draft February 12, 2026 23:10
@charlesli640 charlesli640 force-pushed the charlesli/input_sharding branch 6 times, most recently from 660b637 to 504c66a Compare February 19, 2026 18:25
@charlesli640 charlesli640 marked this pull request as ready for review February 19, 2026 18:25
@charlesli640 charlesli640 force-pushed the charlesli/input_sharding branch from 504c66a to afc474b Compare February 19, 2026 18:29
Copy link
Collaborator

@NuojCheng NuojCheng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Charles! Just some minor comments

@charlesli640 charlesli640 force-pushed the charlesli/input_sharding branch 12 times, most recently from 31860b0 to 5a1e8c9 Compare February 25, 2026 17:41
Copy link
Collaborator

@richjames0 richjames0 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

Using inspect to get call stacktrace

Cmd to generate input_shardings.json files:
  python -m tests.utils.run_sharding_dump
@charlesli640 charlesli640 force-pushed the charlesli/input_sharding branch from d7a2f5f to d4bc454 Compare February 26, 2026 19:06
@copybara-service copybara-service bot merged commit 0b6a8d3 into AI-Hypercomputer:main Feb 26, 2026
44 of 49 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants